feat: add Qwen3.5 35B A3B model support on TRN2 #48

Open
YantaoShen wants to merge 4 commits into aws-neuron:main from YantaoShen:feat/qwen3.5-35b-a3b

@YantaoShen

Issue #, if available:

Description of changes:

feat: add Qwen3.5 35B A3B model support on TRN2

Add inference support for Qwen3.5-35B-A3B (MoE) on AWS Trainium2.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

YantaoShen requested a review from a team on March 30, 2026, 23:07
all_logits = local_logits

# Argmax on CPU
next_id = all_logits.argmax(dim=-1, keepdim=True).to(dtype=torch.int) # (B, 1)
Contributor

Let's make sure this is properly fixed before merging

YantaoShen and others added 3 commits April 17, 2026 15:45
…-topk

  Replaces the CPU argmax path in _sample_token with a single greedy_sampling
  kernel that does RMSNorm + lm_head matmul + all_gather(full logits) + global
  topk entirely on device. Per-step DtoH is now (B,) uint32 = 4 bytes per
  sequence instead of (B, vocab_per_device) f32 = ~248 KB per sequence, and
  the gloo all_gather + torch argmax on CPU are both eliminated.

  Kernel graph is gather-then-topk, not the usual topk-then-gather-then-index:
  neuronx-cc 2.23 miscompiles topk when all_gather is downstream in the same
  kernel for certain vocab_per_device sizes (including 62080 used at TP=4),
  producing wrong token IDs. Keeping topk strictly downstream of all_gather
  sidesteps this, and taking topk over the gathered full-vocab logits directly
  returns the global winner ID -- no rank-offset arithmetic, no dynamic index.

  - kernels/sampling.py: replace compute_logits with greedy_sampling
  - qwen3_5.py: compile greedy_sampling, swap (B, vpd) f32 logits buffers
    for (B,) uint32 next_id buffers, simplify _sample_token to a single
    DtoH + reshape

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
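
For readers following the commit message above, here is a minimal host-side sketch of the two orderings it contrasts. It is plain Python, not the Neuron kernel, and says nothing about the compiler behavior described in the commit. The helper names (`shard_vocab`, `gather_then_topk`, `topk_then_gather`, `dtoh_bytes`) are hypothetical and do not come from the PR. It assumes the vocabulary is split into equal contiguous shards per rank, and only illustrates that top-1 over the gathered logits yields the global token ID directly, while the per-rank variant needs rank-offset arithmetic to reach the same ID.

```python
from typing import List


def shard_vocab(logits: List[float], tp: int) -> List[List[float]]:
    """Split one full-vocab logits row into `tp` equal contiguous shards."""
    assert len(logits) % tp == 0, "vocab must divide evenly across ranks"
    vpd = len(logits) // tp  # vocab_per_device
    return [logits[r * vpd:(r + 1) * vpd] for r in range(tp)]


def argmax(row: List[float]) -> int:
    """Index of the largest value; ties resolve to the lowest index."""
    return max(range(len(row)), key=row.__getitem__)


def gather_then_topk(shards: List[List[float]]) -> int:
    """Concatenate all shards (stand-in for all_gather), then take top-1.

    The winning index is already the global token ID.
    """
    full = [x for shard in shards for x in shard]
    return argmax(full)


def topk_then_gather(shards: List[List[float]]) -> int:
    """Take top-1 per shard, then pick the best candidate across ranks.

    Each local index must be shifted by rank * vocab_per_device.
    """
    vpd = len(shards[0])
    candidates = []
    for rank, shard in enumerate(shards):
        local = argmax(shard)
        candidates.append((shard[local], rank * vpd + local))
    # Highest value wins; on ties prefer the lowest global ID.
    _, best_id = max(candidates, key=lambda c: (c[0], -c[1]))
    return best_id


def dtoh_bytes(batch: int, elems_per_row: int, bytes_per_elem: int) -> int:
    """Size in bytes of a (batch, elems_per_row) device-to-host transfer."""
    return batch * elems_per_row * bytes_per_elem


logits = [0.1, 2.5, -1.0, 0.3, 4.2, 0.0, 1.1, 3.9]
shards = shard_vocab(logits, tp=4)
print(gather_then_topk(shards), topk_then_gather(shards))  # prints: 4 4
```

With the figures quoted in the commit message, `dtoh_bytes(1, 62080, 4)` evaluates to 248320 bytes for one row of f32 logits, against `dtoh_bytes(1, 1, 4)`, or 4 bytes, for one uint32 token ID.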
2 participants